
Update Gemma2 attention scale #694

Merged: 2 commits into TransformerLensOrg:dev on Aug 11, 2024

Conversation

@mntss (Contributor) commented on Aug 7, 2024

Description

The current configuration uses an incorrect attention scale. According to the DeepMind implementation, the 2b and 9b versions use sqrt(d_head), which is the default scale in TransformerLens:
https://github.com/google-deepmind/gemma/blob/a0504162f99a1c238efb37b8197e711c0f3808fd/gemma/transformer.py#L152-L174

Fixes #693
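Concretely, a minimal sketch of what the corrected scaling means (illustrative values only, not the literal diff; it assumes d_head = 256 for these models):

```python
import math

# Gemma-2 2b and 9b use a head dimension of 256 (assumption based on the
# DeepMind reference above), so the attention logits should be scaled by
# sqrt(d_head) rather than a model-specific constant.
d_head = 256
attn_scale = math.sqrt(d_head)  # 16.0

# In TransformerLens terms, attention scores are computed as
#   scores = (q @ k.transpose(-2, -1)) / cfg.attn_scale
# so cfg.attn_scale should equal sqrt(cfg.d_head) for these models.
```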

Type of change


  • Bug fix (non-breaking change which fixes an issue)


Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@neelnanda-io (Collaborator) commented:

Huh, you seem to be correct, my bad. I'm not sure why my sanity checks didn't show an enormous divergence.
Can you check whether the attention patterns are notably closer to HuggingFace after this change?

@mntss (Contributor, Author) commented on Aug 7, 2024

The HF implementation uses SDPA, so I'd need to switch the attention implementation to access the pattern. I compared the output (hook_z) across layers instead:

[Image: per-layer comparison of hook_z outputs between TransformerLens and HuggingFace]
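For anyone who wants to reproduce this kind of check, a rough sketch of caching hook_z per layer in TransformerLens (the model name and prompt are assumptions; the HF-side values would need to be extracted separately, e.g. via torch forward hooks on the corresponding attention modules):

```python
from transformer_lens import HookedTransformer

# Load Gemma-2 2b in TransformerLens (model name assumed; pick the variant you need).
model = HookedTransformer.from_pretrained("google/gemma-2-2b")
tokens = model.to_tokens("The capital of France is")

# Cache all activations and read the per-layer attention head outputs (hook_z).
_, cache = model.run_with_cache(tokens)
for layer in range(model.cfg.n_layers):
    z = cache["z", layer]  # shape: [batch, pos, n_heads, d_head]
    print(f"layer {layer}: hook_z norm = {z.norm().item():.3f}")

# Comparing these tensors against the equivalent activations from the HuggingFace
# model gives the per-layer divergence summarized in the plot above.
```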

@ccp123456 commented:

This issue is very helpful for me, and I have also been confused by it! How do I adjust TransformerLens' Gemma so that it gives the same results as HF? Which file should I add the above code to?

bryce13950 changed the base branch from main to dev on August 11, 2024 at 22:59
bryce13950 changed the base branch from dev to main on August 11, 2024 at 23:00
bryce13950 changed the base branch from main to dev on August 11, 2024 at 23:00
bryce13950 merged commit e30f96b into TransformerLensOrg:dev on Aug 11, 2024
11 checks passed
@bryce13950 (Collaborator) commented:

@mntss Thank you very much for this! @ccp123456 this will be included in a release relatively quickly, so you should be able to make use of it right away. The question of 100% accuracy is a bit more complicated, and making sure the Gemma models are 100% accurate is a high-priority task at the moment. If you are interested in knowing more, DM me on Slack. (If you are not on the Slack channel, let me know and I can give you access.)
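For the question above about matching HF, one possible sanity check (a sketch, with the model name and prompt as assumptions; small numerical differences are still expected from weight processing and dtype) is to compare the final logits directly:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

name = "google/gemma-2-2b-it"  # assumed model; use whichever variant you care about
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForCausalLM.from_pretrained(name)
tl_model = HookedTransformer.from_pretrained(name)

tokens = tokenizer("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens)

# With the corrected attention scale the gap should be small; a large gap would
# point to a remaining configuration mismatch.
print((hf_logits - tl_logits).abs().max().item())
```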

Linked issue (may be closed by merging this pull request): [Bug Report] Gemma-2-2b-it output logit doesn't match with huggingface